{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YoURrh11fbIc"
      },
      "source": [
        "# RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment\n",
        "\n",
        "This notebook shows how RAFT can be used to fine-tune Stable Diffusion with LoRA, using a CLIP-based aesthetic predictor as the reward model.\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "Curious how this works? Read our [paper](https://arxiv.org/abs/2304.06767) to explore the intricacies of our innovative approach."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BzmCovNKkwbi"
      },
      "source": [
        "## Initial Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n7TI5hirlzn8"
      },
      "outputs": [],
      "source": [
        "#@title Install the required libs\n",
        "# diffusers is installed from the GitHub main branch; the remaining packages come from PyPI.\n",
        "%pip install -U -qq git+https://github.com/huggingface/diffusers.git\n",
        "%pip install -q accelerate transformers ftfy bitsandbytes gradio natsort safetensors xformers datasets\n",
        "%pip install -qq \"ipywidgets>=7,<8\"\n",
        "# -O overwrites the script in place, so re-running this cell refreshes it instead of\n",
        "# keeping the old copy and saving the new download as train_text_to_image_lora.py.1.\n",
        "!wget -q -O train_text_to_image_lora.py https://raw.githubusercontent.com/OptimalScale/LMFlow/main/experimental/RAFT-diffusion/train_text_to_image_lora.py"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "fvCBZCnrqcX1"
      },
      "outputs": [],
      "source": [
        "#@title Install CLIP\n",
        "\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install git+https://github.com/openai/CLIP.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "guDgmswnmW-4"
      },
      "outputs": [],
      "source": [
        "#@title Import required libraries\n",
        "import argparse\n",
        "import itertools\n",
        "import math\n",
        "import os\n",
        "import shutil\n",
        "from os.path import expanduser  # pylint: disable=import-outside-toplevel\n",
        "from urllib.request import urlretrieve  # pylint: disable=import-outside-toplevel\n",
        "from contextlib import nullcontext\n",
        "import random\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import torch.utils.checkpoint\n",
        "from torch.utils.data import Dataset\n",
        "import concurrent\n",
        "import PIL\n",
        "from accelerate import Accelerator\n",
        "from accelerate.logging import get_logger\n",
        "from accelerate.utils import set_seed\n",
        "from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler\n",
        "from diffusers.optimization import get_scheduler\n",
        "from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\n",
        "from PIL import Image\n",
        "from torchvision import transforms\n",
        "from tqdm.auto import tqdm\n",
        "from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
        "import clip\n",
        "import bitsandbytes as bnb\n",
        "from torch.utils.data import DataLoader\n",
        "def image_grid(imgs, rows, cols):\n",
        "    \"\"\"Paste PIL images into a single ``rows`` x ``cols`` grid image.\n",
        "\n",
        "    Images are placed row by row (left to right, then top to bottom). The size of\n",
        "    every grid cell is taken from the first image.\n",
        "\n",
        "    Args:\n",
        "        imgs: sequence of PIL images; its length must equal ``rows * cols``.\n",
        "        rows: number of grid rows.\n",
        "        cols: number of grid columns.\n",
        "\n",
        "    Returns:\n",
        "        A new RGB image of size ``(cols * w, rows * h)``.\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: if ``len(imgs) != rows * cols``.\n",
        "    \"\"\"\n",
        "    assert len(imgs) == rows * cols, (\n",
        "        f\"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}\"\n",
        "    )\n",
        "\n",
        "    w, h = imgs[0].size\n",
        "    grid = Image.new('RGB', size=(cols * w, rows * h))\n",
        "\n",
        "    for i, img in enumerate(imgs):\n",
        "        row, col = divmod(i, cols)\n",
        "        grid.paste(img, box=(col * w, row * h))\n",
        "    return grid"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f4D64FI9pI38"
      },
      "source": [
        "## Loading Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7IryKE4wq0SZ",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Creating Dataloader\n",
        "\n",
        "# CIFAR-10 class names, used here as text prompts.\n",
        "prompt_labels = ['airplane','automobile','bird','deer','dog','cat','frog','horse','ship','truck']\n",
        "# One row per prompt; per-epoch score columns are added to this frame during fine-tuning.\n",
        "prompts = pd.DataFrame({'prompts': prompt_labels})\n",
        "\n",
        "class CIFAR10Dataset(Dataset):\n",
        "    \"\"\"Map-style dataset that yields one text prompt per index.\n",
        "\n",
        "    Args:\n",
        "        prompt_frame: optional DataFrame whose first column holds the prompt\n",
        "            strings. Defaults to the notebook-level ``prompts`` frame.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, prompt_frame=None):\n",
        "        # Reading a module-level name needs no ``global`` declaration.\n",
        "        frame = prompts if prompt_frame is None else prompt_frame\n",
        "        self.prompts = frame.iloc[:, 0]\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"Return the number of prompts.\"\"\"\n",
        "        return len(self.prompts)\n",
        "\n",
        "    def __getitem__(self, index):\n",
        "        \"\"\"Return the prompt stored at position ``index``.\"\"\"\n",
        "        return self.prompts.iloc[index]\n",
        "\n",
        "#@markdown Please mention the batch size.\n",
        "batch_size =5 #@param {type:\"integer\"}\n",
        "\n",
        "\n",
        "dataset = CIFAR10Dataset()\n",
        "# shuffle=False keeps batch order aligned with the row order of `prompts`; the training\n",
        "# loop derives row indices from the batch number when it records scores.\n",
        "finetune_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BWH9vc1kvhvC"
      },
      "source": [
        "## Loading CLIP"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lJAguhs1d89L"
      },
      "outputs": [],
      "source": [
        "def get_aesthetic_model(clip_model=\"vit_l_14\"):\n",
        "    \"\"\"Load the LAION linear aesthetic predictor for a CLIP backbone.\n",
        "\n",
        "    The weights are downloaded to ``~/.cache/emb_reader`` when no file exists there\n",
        "    yet, and are read from that cache otherwise.\n",
        "\n",
        "    Args:\n",
        "        clip_model: ``\"vit_l_14\"`` (768-d input) or ``\"vit_b_32\"`` (512-d input).\n",
        "\n",
        "    Returns:\n",
        "        A ``torch.nn.Linear(dim, 1)`` with the downloaded weights, in eval mode.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``clip_model`` is not one of the supported names.\n",
        "    \"\"\"\n",
        "    embedding_dims = {\"vit_l_14\": 768, \"vit_b_32\": 512}\n",
        "    # Validate before touching the network so an unsupported name fails fast\n",
        "    # instead of attempting a download first.\n",
        "    if clip_model not in embedding_dims:\n",
        "        raise ValueError(\n",
        "            f\"unsupported clip_model {clip_model!r}; expected one of {sorted(embedding_dims)}\"\n",
        "        )\n",
        "\n",
        "    home = expanduser(\"~\")\n",
        "    cache_folder = home + \"/.cache/emb_reader\"\n",
        "    path_to_model = cache_folder + \"/sa_0_4_\"+clip_model+\"_linear.pth\"\n",
        "    if not os.path.exists(path_to_model):\n",
        "        os.makedirs(cache_folder, exist_ok=True)\n",
        "        url_model = (\n",
        "            \"https://github.com/LAION-AI/aesthetic-predictor/blob/main/sa_0_4_\"+clip_model+\"_linear.pth?raw=true\"\n",
        "        )\n",
        "        urlretrieve(url_model, path_to_model)\n",
        "\n",
        "    m = torch.nn.Linear(embedding_dims[clip_model], 1)\n",
        "    # map_location=\"cpu\" makes loading independent of the device the checkpoint was\n",
        "    # saved from; callers move the module afterwards with .to(device).\n",
        "    # NOTE(review): torch.load unpickles a downloaded file; consider weights_only=True\n",
        "    # if the installed torch version supports it.\n",
        "    s = torch.load(path_to_model, map_location=\"cpu\")\n",
        "    m.load_state_dict(s)\n",
        "    m.eval()\n",
        "    return m\n",
        "\n",
        "# Run on the GPU when torch reports one as available, otherwise on the CPU.\n",
        "if torch.cuda.is_available():\n",
        "    device = \"cuda\"\n",
        "else:\n",
        "    device = \"cpu\"\n",
        "\n",
        "# Linear aesthetic predictor for ViT-L/14 features, moved to the chosen device.\n",
        "amodel = get_aesthetic_model(clip_model=\"vit_l_14\")\n",
        "amodel = amodel.to(device)\n",
        "amodel.eval()\n",
        "\n",
        "# CLIP model and the preprocessing callable returned alongside it.\n",
        "model, preprocess = clip.load('ViT-L/14', device=device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0RPeQGHUzUZp"
      },
      "source": [
        "## Evaluating Aesthetic Score"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s61Ljr9Sd89M"
      },
      "outputs": [],
      "source": [
        "def get_image_score(image):\n",
        "    \"\"\"Return the aesthetic score of a single image as a Python float.\n",
        "\n",
        "    The image goes through the notebook-level ``preprocess`` and CLIP ``model``; the\n",
        "    L2-normalised embedding is then fed to the linear predictor ``amodel``.\n",
        "    \"\"\"\n",
        "    batch = preprocess(image).unsqueeze(0).to(device)\n",
        "    with torch.no_grad():\n",
        "        features = model.encode_image(batch).to(device)\n",
        "        # Normalise in place first, then cast to float32 for the predictor.\n",
        "        features /= features.norm(dim=-1, keepdim=True)\n",
        "        features = features.to(torch.float32)\n",
        "        score = amodel(features)\n",
        "    return float(score)\n",
        "    \n",
        "def get_max_score(image_list,index,epoch=0):\n",
        "    \"\"\"Score every candidate image of one prompt and pick the best one.\n",
        "\n",
        "    The best score is also written to the notebook-level ``prompts`` frame, at row\n",
        "    ``index`` in column ``f'Epoch{epoch} Scores'``.\n",
        "\n",
        "    Args:\n",
        "        image_list: candidate images generated for a single prompt.\n",
        "        index: row label of that prompt in ``prompts``.\n",
        "        epoch: epoch number used to build the score column name.\n",
        "\n",
        "    Returns:\n",
        "        ``[best_score, best_position]``, where ``best_position`` is the position in\n",
        "        ``image_list`` of the first image that reaches the best score.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``image_list`` is empty.\n",
        "    \"\"\"\n",
        "    score_list = [get_image_score(image) for image in image_list]\n",
        "    torch.cuda.empty_cache()\n",
        "    if not score_list:\n",
        "        raise ValueError(\"image_list must contain at least one image\")\n",
        "\n",
        "    # Single pass over the scores; max() keeps the first maximal element, which\n",
        "    # matches score_list.index(max(score_list)).\n",
        "    best_position = max(range(len(score_list)), key=score_list.__getitem__)\n",
        "    best_score = score_list[best_position]\n",
        "\n",
        "    prompts.loc[index,f'Epoch{epoch} Scores'] = best_score\n",
        "    return [best_score, best_position]\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ak1jArUL0eCi"
      },
      "source": [
        "## Parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jv6WYJos0iT5",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Settings for the model\n",
        "\n",
        "#@markdown All settings have been configured to achieve optimal output. Changing them is not advisable.\n",
        "\n",
        "#@markdown `resolution`: width and height passed to the pipeline and to the training script.\n",
        "resolution=256 #@param {type:\"integer\"}\n",
        "\n",
        "#@markdown `num_images_per_prompt`: number of candidate images generated for each prompt.\n",
        "num_images_per_prompt=10 #@param {type:\"integer\"}\n",
        "\n",
        "#@markdown `epochs`: number of generate-score-train rounds.\n",
        "epochs=10 #@param {type:\"integer\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7gFbnMaLd89N",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "# @title Setting Stable Diffusion pipeline\n",
        "model_id = \"runwayml/stable-diffusion-v1-5\"\n",
        "# Load the pipeline with float16 weights, then move it to the selected device.\n",
        "pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)\n",
        "pipe = pipe.to(device)\n",
        "pipe.enable_xformers_memory_efficient_attention()\n",
        "torch.cuda.empty_cache()\n",
        "\n",
        "#@markdown Check the `set_progress_bar_config` option if you would like to hide the progress bar for image generation\n",
        "set_progress_bar_config= False #@param {type:\"boolean\"}\n",
        "pipe.set_progress_bar_config(disable=set_progress_bar_config)\n",
        "\n",
        "# Swap in a DPM-Solver multistep scheduler built from the current scheduler's config.\n",
        "scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
        "pipe.scheduler = scheduler\n",
        "\n",
        "torch.cuda.empty_cache()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9U2P_PUN-5xX"
      },
      "source": [
        "## Finetuning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F-m6S9Sg-yS_",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Generating images on the pretrained model\n",
        "\n",
        "#@markdown Check the box to generate images using the pretrained model.\n",
        "generate_pretrained_model_images= True #@param {type:\"boolean\"}\n",
        "\n",
        "if generate_pretrained_model_images:\n",
        "  image_list=[]\n",
        "  for prompt_list in finetune_dataloader:\n",
        "    image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images\n",
        "    image_list+=image\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "  # One grid row per prompt; assumes the pipeline returns images grouped by prompt.\n",
        "  grid = image_grid(image_list, len(prompts),num_images_per_prompt)\n",
        "  grid.save(\"pretrained.png\")\n",
        "  # A bare `grid` inside the `if` body is not the cell's last top-level expression,\n",
        "  # so it would not be rendered; display it explicitly.\n",
        "  display(grid)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kPfHR4HQd89N"
      },
      "outputs": [],
      "source": [
        "#@title Run training\n",
        "\n",
        "# Settings handed to the `accelerate launch` command below through environment variables.\n",
        "os.environ['MODEL_NAME'] = model_id\n",
        "os.environ['OUTPUT_DIR'] = \"./CustomModel/\"\n",
        "# Number of highest-scoring (image, prompt) pairs kept for fine-tuning in each epoch.\n",
        "topk=8\n",
        "training_steps_per_epoch=topk*10\n",
        "os.environ['CHECKPOINTING_STEPS']=str(training_steps_per_epoch)\n",
        "os.environ['RESOLUTION']=str(resolution)\n",
        "os.environ['LEARNING_RATE']=str(9e-6)\n",
        "\n",
        "# Remove the outputs of any previous run. ignore_errors=True keeps this best-effort\n",
        "# (the directories do not exist on a first run) without a bare `except:` that would\n",
        "# also swallow KeyboardInterrupt.\n",
        "shutil.rmtree('./CustomModel', ignore_errors=True)\n",
        "shutil.rmtree('./trainingdataset/imagefolder/', ignore_errors=True)\n",
        "\n",
        "# model_id is deliberately not reassigned here: MODEL_NAME above already carries the\n",
        "# value chosen when the pipeline was created, and the two must stay in agreement.\n",
        "\n",
        "\n",
        "# `import concurrent` alone does not load the `futures` submodule; import it explicitly\n",
        "# so the executor below does not depend on another library having imported it first.\n",
        "import concurrent.futures\n",
        "\n",
        "# Each epoch: generate candidates, keep the best image per prompt, fine-tune on the\n",
        "# topk best pairs. The extra final pass (epoch == epochs) only generates and scores.\n",
        "for epoch in range(epochs+1):\n",
        "  print(\"Epoch: \",epoch)\n",
        "  # max_train_steps is cumulative because training resumes from the latest checkpoint.\n",
        "  training_steps=str(training_steps_per_epoch*(epoch+1))\n",
        "  os.environ['TRAINING_STEPS']=training_steps\n",
        "  os.environ['TRAINING_DIR'] = f'./trainingdataset/imagefolder/{epoch}'\n",
        "\n",
        "  training_prompts=[]\n",
        "  prompts[f'Epoch{epoch} Scores']=np.nan\n",
        "\n",
        "  for step, prompt_list in enumerate(finetune_dataloader):\n",
        "    image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images\n",
        "\n",
        "    # Split the flat result into one group of candidates per prompt.\n",
        "    image_list=[\n",
        "      image[i*num_images_per_prompt:(i+1)*num_images_per_prompt]\n",
        "      for i in range(len(image)//num_images_per_prompt)\n",
        "    ]\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "    # Row labels in `prompts` for the prompts of this batch.\n",
        "    step_list=list(range(step*batch_size, step*batch_size+len(image_list)))\n",
        "    with concurrent.futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:\n",
        "      # Collect the results inside the `with` block, while the executor is still open.\n",
        "      score_index=list(executor.map(get_max_score,image_list,step_list,[epoch]*len(step_list)))\n",
        "\n",
        "    for candidates, prompt, (best_score, best_position) in zip(image_list, prompt_list, score_index):\n",
        "      training_prompts.append([best_score,candidates[best_position],prompt])\n",
        "\n",
        "  # Keep the (image, prompt) pairs of the topk highest scores.\n",
        "  training_prompts=[row[1:3] for row in sorted(training_prompts,key=lambda x: (x[0]),reverse=True)[:topk]]\n",
        "  training_prompts=pd.DataFrame(training_prompts)\n",
        "\n",
        "  os.makedirs(f\"./trainingdataset/imagefolder/{epoch}/train/\", exist_ok=True)\n",
        "  os.makedirs(\"./CustomModel/\", exist_ok=True)\n",
        "  for i in range(len(training_prompts)):\n",
        "    training_prompts.iloc[i,0].save(f'./trainingdataset/imagefolder/{epoch}/train/{i}.png')\n",
        "\n",
        "  # Write metadata.csv with the columns `text` and `file_name`; the image column is dropped.\n",
        "  training_prompts['file_name']=[f\"{i}.png\" for i in range(len(training_prompts))]\n",
        "  training_prompts.columns = ['0','text','file_name']\n",
        "  training_prompts.drop('0',axis=1,inplace=True)\n",
        "  training_prompts.to_csv(f'./trainingdataset/imagefolder/{epoch}/train/metadata.csv',index=False)\n",
        "  torch.cuda.empty_cache()\n",
        "\n",
        "  if epoch<epochs:\n",
        "    !accelerate launch --num_processes=1 --mixed_precision='fp16' --dynamo_backend='no' --num_machines=1 train_text_to_image_lora.py \\\n",
        "        --pretrained_model_name_or_path=$MODEL_NAME \\\n",
        "        --train_data_dir=$TRAINING_DIR \\\n",
        "        --resolution=$RESOLUTION \\\n",
        "        --train_batch_size=8 \\\n",
        "        --gradient_accumulation_steps=1 \\\n",
        "        --gradient_checkpointing \\\n",
        "        --max_grad_norm=1 \\\n",
        "        --mixed_precision=\"fp16\" \\\n",
        "        --max_train_steps=$TRAINING_STEPS \\\n",
        "        --learning_rate=$LEARNING_RATE \\\n",
        "        --lr_warmup_steps=0 \\\n",
        "        --enable_xformers_memory_efficient_attention \\\n",
        "        --dataloader_num_workers=5 \\\n",
        "        --output_dir=$OUTPUT_DIR \\\n",
        "        --lr_warmup_steps=0 \\\n",
        "        --seed=1234 \\\n",
        "        --checkpointing_steps=$CHECKPOINTING_STEPS \\\n",
        "        --resume_from_checkpoint=\"latest\" \\\n",
        "        --lr_scheduler='constant' \n",
        "  \n",
        "  # Load the attention weights from the training output directory into the pipeline.\n",
        "  pipe.unet.load_attn_procs(f'./CustomModel/')\n",
        "  torch.cuda.empty_cache()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rglR5r5gahMv"
      },
      "source": [
        "## Results\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kcf9aY6od89O",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Generating results on the fine-tuned model\n",
        "\n",
        "#@markdown Check the box to generate images using the fine-tuned model.\n",
        "generate_finetuned_model_images= True #@param {type:\"boolean\"}\n",
        "\n",
        "if generate_finetuned_model_images:\n",
        "  image_list=[]\n",
        "  # Load the attention weights saved under ./CustomModel before generating.\n",
        "  pipe.unet.load_attn_procs('./CustomModel')\n",
        "  for prompt_list in finetune_dataloader:\n",
        "    image=pipe(prompt_list,num_images_per_prompt=num_images_per_prompt,width=resolution,height=resolution).images\n",
        "    image_list+=image\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "  # One grid row per prompt; assumes the pipeline returns images grouped by prompt.\n",
        "  grid = image_grid(image_list, len(prompts),num_images_per_prompt)\n",
        "  grid.save(\"trained.png\")\n",
        "  # A bare `grid` inside the `if` body is not the cell's last top-level expression,\n",
        "  # so it would not be rendered; display it explicitly.\n",
        "  display(grid)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "BzmCovNKkwbi",
        "f4D64FI9pI38",
        "BWH9vc1kvhvC",
        "0RPeQGHUzUZp",
        "Ak1jArUL0eCi",
        "9U2P_PUN-5xX",
        "rglR5r5gahMv"
      ]
    },
    "gpuClass": "premium",
    "kernelspec": {
      "display_name": "deepanshu",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.16"
    },
    "vscode": {
      "interpreter": {
        "hash": "cd95ac8400f934ca97b7c7125945f5f2a4616fc88b7668f808354bfbb29c51b3"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
